# Copyright (c) 2024-present AI-Labs

from fastapi import APIRouter, Request, UploadFile, Form

import os, subprocess
import uuid
from datetime import datetime
import torch
from openai import OpenAI
import base64

from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

from configs import config

"""
路由信息设置
"""
router = APIRouter(
    prefix='/vision_language/qwen2vl',
    tags = ['图片识别']
)

# GiteeAI 平台部署加速
# model_path = 'Qwen/Qwen2-VL-2B-Instruct'

# 下载到本地
# model_path = 'models/Qwen/Qwen2-VL-2B-Instruct'

# 使用配置文件
model_path = config.service.qwen2vl.model_path
device = config.service.qwen2vl.device
image_limit = config.service.qwen2vl.image_limit

model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, trust_remote_code=True, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
model = model.to(device=device, dtype=torch.float16)
model.eval()


"""
视觉语言多模态数据处理，支持：图片内容识别问答、视频内容识别问答
"""
@router.post("/v1/chat")
async def chat_with_media_v1(request: Request):
    # 获取用户请求数据
    data = await request.json()
    # 指定语言模型参数
    params = {
        'top_p': 0.8,
        'top_k': 100,
        'temperature': 0.7,
        'repetition_penalty': 1.05,
        "max_new_tokens": 8192
    }

    # 先将用户提交的多媒体材料保存到本地
    localdir = f"statics/upload/{datetime.now().strftime('%Y-%m-%d')}/{uuid.uuid4()}"
    os.makedirs(localdir, exist_ok=True)
    localfile = f"{localdir}/{data['filename']}"

    media_type = data['media_type']
    with open(localfile, "wb") as f:
        f.write(base64.b64decode(data['media']))
    
    # 根据多媒体类型进行不同的处理
    if media_type == "image":
        # 如果用户提交的是图片，则直接组织对话消息
        messages = [
        {
                "role": "user",
                "content": [
                    {
                        "type": media_type,
                        media_type: localfile,
                    },
                    {"type": "text", "text": data['text']},
                ],
            }
        ]
    if media_type == "video":
        # 如果用户提交的是视频，则需要抽取视频中的帧，根据配置文件中的帧数抽取图片
        imagedir = f"{localdir}/images"
        os.makedirs(imagedir, exist_ok=True)
        subprocess.getoutput(f"ffmpeg -i {localfile} -r 1 -t {image_limit} {imagedir}/image-%3d.jpg")
        # 抽取的图片列表
        imagelist = [f"{imagedir}/{image}" for image in os.listdir(imagedir)]
        # 根据图片列表组织对话消息
        messages = [
        {
                "role": "user",
                "content": [
                    {
                        "type": media_type,
                        media_type: imagelist,
                    },
                    {"type": "text", "text": data['text']},
                ],
            }
        ]
    # 对话模板设置
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # 处理输入信息
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)
    
    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=8192)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    # 处理输出信息
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    
    # 得到对话结果
    answer = output_text[0]

    # 清理显存缓存
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # 返回结果
    return {"text": answer}


@router.post("/v1/marketing")
def marketing_documents(image: UploadFile,
                        product_name=Form(None),
                        product_tags=Form(None),
                        product_gender=Form(None),
                        product_season=Form(None),
                        product_price=Form(None),
                        product_style=Form(None),
                        product_material=Form(None),
                        product_advantage=Form(None),
                        product_description=Form(None)
                        ):
    """Generate marketing ("种草") copy for a clothing-product photo.

    The uploaded image is first described via the vision-language model, then
    the description plus the submitted product attributes are sent to a chat
    completion model that writes the promotional text.

    Returns the image description followed by the generated copy, separated
    by a blank line.
    """
    # NOTE(review): chat_with_image is not defined in this module — presumably a
    # helper elsewhere in the project; verify it accepts (UploadFile, str).
    image_description = chat_with_image(image, "请详细描述这幅图片，精准捕获图片中服装的每一个细节")

    # System prompt sets the persona; the image description is injected as a
    # second system message; the user message carries the product attributes.
    messages = [{
                "role": "system",
                "content": "你是一位优秀的服装商品智能销售专家，你需要推销一件服装，你需要放大商品的优点，激发用户的购买欲望！"
            }, {
                "role": "system",
                "content": image_description
            }, {
                "role": "user",
                "content": """这件服装商品的详细信息如下：
商品名称：{product_name}
商品标签：{product_tags}
商品类型：{product_gender}
适合季节：{product_season}
商品价格：{product_price}
设计风格：{product_style}
服装材质：{product_material}
商品优势：{product_advantage}
商品描述：{product_description}。

请写一段电商平台的种草文案。""".format(
            # Fix: product_advantage was collected from the form but had no
            # placeholder in the template, so it was silently dropped.
            product_name = product_name,
            product_tags = product_tags,
            product_gender = product_gender,
            product_season = product_season,
            product_price = product_price,
            product_style = product_style,
            product_material = product_material,
            product_advantage = product_advantage,
            product_description = product_description
        )
            }]

    # NOTE(review): the getenv fallback is the env-var *name* itself
    # (config.service.chat.api_key_env), not an actual key — looks like it
    # should fall back to a configured key value; confirm against the config
    # schema before changing.
    client = OpenAI(api_key=os.getenv(config.service.chat.api_key_env, default=config.service.chat.api_key_env), base_url=config.service.chat.base_url)

    answer = client.chat.completions.create(
        model="chat",
        messages=messages,
        stream=False,
        max_tokens=4096,
        temperature=0.7,
        presence_penalty=1.2,
        top_p=0.8,
    ).choices[0].message.content

    # Release cached GPU memory held by the local vision model between requests.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    return image_description + "\n\n" + answer
